#
# Acknowledgement:
#   This code is adapted from LLaMA-Factory:
#   https://github.com/hiyouga/LLaMA-Factory/blob/1653c2243895cca69c6efecc158881ed6e7c4129/src/llmtuner/data/template.py
#   Much thanks to the authors of LLaMA-Factory for their great work!
#

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import tiktoken

from rtp_llm.frontend.tokenizer_factory.tokenizers import BaseTokenizer

logger = logging.getLogger()


@dataclass
class Template:

    prefix: List[Union[str, Dict[str, str]]]
    prompt: List[Union[str, Dict[str, str]]]
    system: str
    sep: List[Union[str, Dict[str, str]]]
    stop_words: List[str]
    use_history: bool
    efficient_eos: bool
    replace_eos: bool

    def encode_oneturn(
        self,
        tokenizer: "BaseTokenizer",
        query: str,
        resp: str,
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None,
    ) -> Tuple[List[int], List[int]]:
        r"""
        Returns a single pair of token ids representing prompt and response respectively.
        """
        system, history = self._format(query, resp, history, system)
        encoded_pairs = self._encode(tokenizer, system, history)
        prompt_ids = []
        for query_ids, resp_ids in encoded_pairs[:-1]:
            prompt_ids = prompt_ids + query_ids + resp_ids
        prompt_ids = prompt_ids + encoded_pairs[-1][0]
        answer_ids = encoded_pairs[-1][1]
        return prompt_ids, answer_ids

    def encode_multiturn(
        self,
        tokenizer: "BaseTokenizer",
        query: str,
        resp: str,
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None,
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.
        """
        system, history = self._format(query, resp, history, system)
        encoded_pairs = self._encode(tokenizer, system, history)
        return encoded_pairs

    def _format(
        self,
        query: str,
        resp: str,
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None,
    ) -> Tuple[str, List[Tuple[str, str]]]:
        r"""
        Aligns inputs to the standard format.
        """
        system = system or self.system  # use system if provided
        history = history if (history and self.use_history) else []
        history = history + [(query, resp)]
        return system, history

    def _get_special_ids(
        self, tokenizer: "BaseTokenizer"
    ) -> Tuple[List[int], List[int]]:
        if tokenizer.bos_token_id is not None and getattr(
            tokenizer, "add_bos_token", True
        ):
            bos_ids = [tokenizer.bos_token_id]
        else:  # baichuan, gpt2, qwen, yi models have no bos token
            bos_ids = []

        if tokenizer.eos_token_id is None:
            raise ValueError("EOS token is required.")

        if self.efficient_eos:
            eos_ids = []
        else:
            eos_ids = [tokenizer.eos_token_id]

        return bos_ids, eos_ids

    def _encode(
        self,
        tokenizer: "BaseTokenizer",
        system: str,
        history: List[Tuple[str, str]],
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: bos + prefix + sep + query    resp + eos
        Turn t: sep + bos + query             resp + eos
        """
        bos_ids, eos_ids = self._get_special_ids(tokenizer)
        sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
        encoded_pairs = []
        for turn_idx, (query, resp) in enumerate(history):
            if turn_idx == 0:
                prefix_ids = self._convert_inputs_to_ids(
                    tokenizer, context=self.prefix, system=system
                )
                if len(prefix_ids) != 0:  # has prefix
                    prefix_ids = bos_ids + prefix_ids + sep_ids
                else:
                    prefix_ids = bos_ids
            else:
                prefix_ids = sep_ids + bos_ids

            query_ids = self._convert_inputs_to_ids(
                tokenizer, context=self.prompt, query=query, idx=str(turn_idx + 1)
            )
            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
            encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
        return encoded_pairs

    def _convert_inputs_to_ids(
        self,
        tokenizer: "BaseTokenizer",
        context: List[Union[str, Dict[str, str]]],
        system: Optional[str] = None,
        query: Optional[str] = None,
        idx: Optional[str] = None,
    ) -> List[int]:
        r"""
        Converts context to token ids.
        """
        if isinstance(
            getattr(tokenizer, "tokenizer", None), tiktoken.Encoding
        ):  # for tiktoken tokenizer (Qwen)
            kwargs = dict(allowed_special="all")
        else:
            kwargs = dict(add_special_tokens=False)

        token_ids = []
        for elem in context:
            if isinstance(elem, str):
                elem = (
                    elem.replace("{{system}}", system, 1)
                    if system is not None
                    else elem
                )
                elem = (
                    elem.replace("{{query}}", query, 1) if query is not None else elem
                )
                elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
                if len(elem) != 0:
                    token_ids = token_ids + tokenizer.encode(elem, **kwargs)
            elif isinstance(elem, dict):
                token_ids = token_ids + [
                    tokenizer.convert_tokens_to_ids(elem.get("token"))
                ]
            else:
                raise ValueError(
                    "Input must be string or dict[str, str], got {}".format(type(elem))
                )

        return token_ids


@dataclass
class Llama2Template(Template):

    def _encode(
        self,
        tokenizer: "BaseTokenizer",
        system: str,
        history: List[Tuple[str, str]],
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: bos + prefix + query    resp + eos
        Turn t: bos + query             resp + eos
        """
        bos_ids, eos_ids = self._get_special_ids(tokenizer)
        encoded_pairs = []
        for turn_idx, (query, resp) in enumerate(history):
            if turn_idx == 0:  # llama2 template has no sep_ids
                query = self.prefix[0].replace("{{system}}", system) + query
            query_ids = self._convert_inputs_to_ids(
                tokenizer, context=self.prompt, query=query
            )
            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
            encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
        return encoded_pairs


templates: Dict[str, Template] = {}


def register_template(
    name: str,
    prefix: List[Union[str, Dict[str, str]]],
    prompt: List[Union[str, Dict[str, str]]],
    system: str,
    sep: List[Union[str, Dict[str, str]]],
    stop_words: Optional[List[str]] = [],
    use_history: Optional[bool] = True,
    efficient_eos: Optional[bool] = False,
    replace_eos: Optional[bool] = False,
) -> None:
    template_class = Llama2Template if name.startswith("llama2") else Template
    templates[name] = template_class(
        prefix=prefix,
        prompt=prompt,
        system=system,
        sep=sep,
        stop_words=stop_words,
        use_history=use_history,
        efficient_eos=efficient_eos,
        replace_eos=replace_eos,
    )


def get_template_and_fix_tokenizer(name: str, tokenizer: BaseTokenizer) -> Template:
    if tokenizer.eos_token_id is None:
        tokenizer.tokenizer.eos_token = "<|endoftext|>"
        logger.info("Add eos token: {}".format(tokenizer.eos_token))

    if tokenizer.pad_token_id is None:
        tokenizer.tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))

    if name is None:  # for pre-training
        return None

    template = templates.get(name, None)
    assert template is not None, "Template {} does not exist.".format(name)

    stop_words = template.stop_words
    if template.replace_eos:
        if not stop_words:
            raise ValueError("Stop words are required to replace the EOS token.")

        tokenizer.tokenizer.eos_token = stop_words[0]
        stop_words = stop_words[1:]
        logger.info("Replace eos token: {}".format(tokenizer.eos_token))

    if stop_words:
        tokenizer.add_special_tokens(
            dict(additional_special_tokens=stop_words),
            replace_additional_special_tokens=False,
        )
        logger.info("Add {} to stop words.".format(",".join(stop_words)))

    return template


register_template(
    name="alpaca",
    prefix=["{{system}}"],
    prompt=["### Instruction:\n{{query}}\n\n### Response:\n"],
    system=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    ),
    sep=["\n\n"],
)


register_template(
    name="aquila",
    prefix=["{{system}}"],
    prompt=["Human: {{query}}###Assistant:"],
    system=(
        "A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions."
    ),
    sep=["###"],
    stop_words=["</s>"],
    efficient_eos=True,
)


register_template(
    name="baichuan",
    prefix=["{{system}}"],
    prompt=[
        {"token": "<reserved_102>"},  # user token
        "{{query}}",
        {"token": "<reserved_103>"},  # assistant token
    ],
    system="",
    sep=[],
    efficient_eos=True,
)


register_template(
    name="baichuan2",
    prefix=["{{system}}"],
    prompt=[
        {"token": "<reserved_106>"},  # user token
        "{{query}}",
        {"token": "<reserved_107>"},  # assistant token
    ],
    system="",
    sep=[],
    efficient_eos=True,
)


register_template(
    name="belle",
    prefix=["{{system}}"],
    prompt=["Human: {{query}}\n\nBelle: "],
    system="",
    sep=["\n\n"],
)


register_template(
    name="bluelm",
    prefix=["{{system}}"],
    prompt=[{"token": "[|Human|]:"}, "{{query}}", {"token": "[|AI|]:"}],
    system="",
    sep=[],
)


register_template(
    name="chatglm2",
    prefix=[{"token": "[gMASK]"}, {"token": "sop"}, "{{system}}"],
    prompt=["[Round {{idx}}]\n\n问：{{query}}\n\n答："],
    system="",
    sep=["\n\n"],
    efficient_eos=True,
)


register_template(
    name="chatglm3",
    prefix=[
        {"token": "[gMASK]"},
        {"token": "sop"},
        {"token": "<|system|>"},
        "\n",
        "{{system}}",
    ],
    prompt=[
        {"token": "<|user|>"},
        "\n",
        "{{query}}",
        {"token": "<|assistant|>"},
        "\n",  # add an extra newline to avoid error in ChatGLM's process_response method
    ],
    system=(
        "You are ChatGLM3, a large language model trained by Zhipu.AI. "
        "Follow the user's instructions carefully. Respond using markdown."
    ),
    sep=[],
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
)


register_template(
    name="chatglm3_raw",  # the raw template for tool tuning
    prefix=[
        {"token": "[gMASK]"},
        {"token": "sop"},
        {"token": "<|system|>"},
        "\n",
        "{{system}}",
    ],
    prompt=[{"token": "<|user|>"}, "\n", "{{query}}", {"token": "<|assistant|>"}],
    system=(
        "You are ChatGLM3, a large language model trained by Zhipu.AI. "
        "Follow the user's instructions carefully. Respond using markdown."
    ),
    sep=[],
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
)


register_template(
    name="codegeex2",
    prefix=[{"token": "[gMASK]"}, {"token": "sop"}, "{{system}}"],
    prompt=["{{query}}"],
    system="",
    sep=[],
)


register_template(
    name="deepseek",
    prefix=["{{system}}"],
    prompt=["User: {{query}}\n\nAssistant:"],
    system="",
    sep=[],
)


register_template(
    name="deepseekcoder",
    prefix=["{{system}}"],
    prompt=["### Instruction:\n{{query}}\n### Response:\n"],
    system=(
        "You are an AI programming assistant, utilizing the Deepseek Coder model, "
        "developed by Deepseek Company, and you only answer questions related to computer science. "
        "For politically sensitive questions, security and privacy issues, "
        "and other non-computer science questions, you will refuse to answer\n"
    ),
    sep=["\n", {"token": "<|EOT|>"}, "\n"],
    stop_words=["<|EOT|>"],
    efficient_eos=True,
)


register_template(
    name="default",
    prefix=["{{system}}"],
    prompt=["Human: {{query}}\nAssistant:"],
    system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    sep=["\n"],
)


register_template(
    name="falcon",
    prefix=["{{system}}"],
    prompt=["User: {{query}}\nFalcon:"],
    system="",
    sep=["\n"],
    efficient_eos=True,
)


register_template(
    name="internlm",
    prefix=["{{system}}"],
    prompt=["<|User|>:{{query}}", {"token": "<eoh>"}, "\n<|Bot|>:"],
    system="",
    sep=[{"token": "<eoa>"}, "\n"],
    stop_words=["<eoa>"],
    efficient_eos=True,
)


register_template(
    name="llama2",
    prefix=["<<SYS>>\n{{system}}\n<</SYS>>\n\n"],
    prompt=["[INST] {{query}} [/INST]"],
    system=(
        "You are a helpful, respectful and honest assistant. "
        "Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, "
        "racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information."
    ),
    sep=[],
)


register_template(
    name="llama2_zh",
    prefix=["<<SYS>>\n{{system}}\n<</SYS>>\n\n"],
    prompt=["[INST] {{query}} [/INST]"],
    system="You are a helpful assistant. 你是一个乐于助人的助手。",
    sep=[],
)


register_template(
    name="mistral",
    prefix=["{{system}}"],
    prompt=["[INST] {{query}} [/INST]"],
    system="",
    sep=[],
)


register_template(
    name="openchat",
    prefix=["{{system}}"],
    prompt=[
        "GPT4 Correct User: {{query}}",
        {"token": "<|end_of_turn|>"},
        "GPT4 Correct Assistant:",
    ],
    system="",
    sep=[{"token": "<|end_of_turn|>"}],
    stop_words=["<|end_of_turn|>"],
    efficient_eos=True,
)


register_template(
    name="qwen",
    prefix=["<|im_start|>system\n{{system}}<|im_end|>"],
    prompt=["<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"],
    system="You are a helpful assistant.",
    sep=["\n"],
    stop_words=["<|im_end|>"],
    replace_eos=True,
)


register_template(
    name="starchat",
    prefix=[
        {"token": "<|system|>"},
        "\n{{system}}",
    ],
    prompt=[
        {"token": "<|user|>"},
        "\n{{query}}",
        {"token": "<|end|>"},
        "\n",
        {"token": "<|assistant|>"},
    ],
    system="",
    sep=[{"token": "<|end|>"}, "\n"],
    stop_words=["<|end|>"],
    efficient_eos=True,
)


register_template(
    name="vanilla",
    prefix=[],
    prompt=["{{query}}"],
    system="",
    sep=[],
    use_history=False,
)


register_template(
    name="vicuna",
    prefix=["{{system}}"],
    prompt=["USER: {{query}} ASSISTANT:"],
    system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    sep=[],
)


register_template(
    name="xuanyuan",
    prefix=["{{system}}"],
    prompt=["Human: {{query}} Assistant:"],
    system=(
        "以下是用户和人工智能助手之间的对话。用户以Human开头，人工智能助手以Assistant开头，"
        "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答，并且总是拒绝参与与不道德、"
        "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
    ),
    sep=[],
)


register_template(
    name="xverse",
    prefix=["{{system}}"],
    prompt=["Human: {{query}}\n\nAssistant: "],
    system="",
    sep=[],
)


register_template(
    name="yayi",
    prefix=[{"token": "<|System|>"}, ":\n{{system}}"],
    prompt=[{"token": "<|Human|>"}, ":\n{{query}}\n\n", {"token": "<|YaYi|>"}, ":"],
    system=(
        "You are a helpful, respectful and honest assistant named YaYi "
        "developed by Beijing Wenge Technology Co.,Ltd. "
        "Always answer as helpfully as possible, while being safe.  "
        "Your answers should not include any harmful, unethical, "
        "racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information."
    ),
    sep=["\n\n"],
    stop_words=["<|End|>"],
)


register_template(
    name="yi",
    prefix=["{{system}}"],
    prompt=["<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"],
    system="",
    sep=["\n"],
    stop_words=["<|im_end|>"],
    replace_eos=True,
)


register_template(
    name="yuan",
    prefix=["{{system}}"],
    prompt=["{{query}}", {"token": "<sep>"}],
    system="",
    sep=["\n"],
    stop_words=["<eod>"],
    replace_eos=True,
)


register_template(
    name="zephyr",
    prefix=[{"token": "<|system|>"}, "\n{{system}}", {"token": "</s>"}],
    prompt=[
        {"token": "<|user|>"},
        "\n{{query}}",
        {"token": "</s>"},
        {"token": "<|assistant|>"},
    ],
    system="You are a friendly chatbot who always responds in the style of a pirate",
    sep=[],
)


register_template(
    name="ziya",
    prefix=["{{system}}"],
    prompt=[{"token": "<human>"}, ":{{query}}\n", {"token": "<bot>"}, ":"],
    system="",
    sep=["\n"],
)

register_template(
    name="internlm2",
    prefix=["<s>[UNUSED_TOKEN_146]system\n{{system}}[UNUSED_TOKEN_145]\n"],
    prompt=[
        "[UNUSED_TOKEN_146]user\n{{query}}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n",
    ],
    system="You are an AI assistant whose name is InternLM (书生·浦语).\n"
    "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
    "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
    sep=[{"token": "[UNUSED_TOKEN_145]"}, "\n"],
    stop_words=[
        "[UNUSED_TOKEN_145]",
        "</s>",
    ],
    efficient_eos=True,
)

register_template(
    name="mixtral",
    prefix=[
        "<s>",
    ],
    prompt=["[INST]{{query}}[/INST]"],
    stop_words=[
        "</s>",
    ],
    sep=[
        "</s>",
    ],
    system="",
)
